{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "59ef9a30-174d-4ea3-b4a0-7c465e4c3f53",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Need to install einops and timm for original omnivore model, and matplotlib for visualization\n",
    "! pip install einops timm matplotlib\n",
    "\n",
    "import torch\n",
    "import torchvision.transforms as T\n",
    "import torchmultimodal.models.omnivore as omnivore\n",
    "\n",
    "from PIL import Image\n",
    "import collections\n",
    "import json\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "55e221ba-528e-4f09-9d49-d0f54a577bc0",
   "metadata": {},
   "outputs": [],
   "source": [
    "def count_parameters(model):\n",
    "    \"\"\"Return the total element count of the parameters of ``model`` that require gradients.\"\"\"\n",
    "    trainable_sizes = [param.numel() for param in model.parameters() if param.requires_grad]\n",
    "    return sum(trainable_sizes)\n",
    "\n",
    "def custom_load_state_dict(model, pretrained_state_dict):\n",
    "    \"\"\"Load ``pretrained_state_dict`` into ``model``, pairing entries by position.\n",
    "\n",
    "    The i-th key of ``pretrained_state_dict`` is renamed to the i-th key of\n",
    "    ``model.state_dict()`` before calling ``model.load_state_dict``, so both\n",
    "    state dicts must list their entries in the same order.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if the two state dicts differ in number of entries.\n",
    "    \"\"\"\n",
    "    pretrained_keys = list(pretrained_state_dict.keys())\n",
    "    model_keys = list(model.state_dict().keys())\n",
    "    # A positional mapping is only well defined when every entry has a counterpart;\n",
    "    # fail early with a clear message instead of an opaque KeyError/IndexError.\n",
    "    if len(pretrained_keys) != len(model_keys):\n",
    "        raise ValueError(\n",
    "            f\"Cannot pair state dict entries by position: pretrained has {len(pretrained_keys)} entries, model has {len(model_keys)}\"\n",
    "        )\n",
    "    updated_pretrained_state_dict = collections.OrderedDict(\n",
    "        (model_key, pretrained_state_dict[pretrained_key])\n",
    "        for pretrained_key, model_key in zip(pretrained_keys, model_keys)\n",
    "    )\n",
    "    model.load_state_dict(updated_pretrained_state_dict)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "67b998df-4c99-4f18-a408-a1feef5c483a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fetch the reference omnivore_swinT model through torch.hub and switch it to eval mode\n",
    "\n",
    "mhub = torch.hub.load(\"facebookresearch/omnivore:main\", model=\"omnivore_swinT\")\n",
    "mhub = mhub.eval()\n",
    "mhub_param_count = count_parameters(mhub)\n",
    "print(mhub_param_count)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5d3de763-8be6-450d-958b-0a01d4dc8b3b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build the torchmultimodal counterpart of the hub model\n",
    "m = omnivore.omnivore_swin_t()\n",
    "\n",
    "# Its trainable-parameter count should equal the one printed for the hub model\n",
    "m_param_count = count_parameters(m)\n",
    "print(m_param_count)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a1168d33-0d11-46df-bef4-51f1973c2f95",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Copy the hub model's weights into m, then put m in eval mode\n",
    "hub_weights = mhub.state_dict()\n",
    "custom_load_state_dict(m, hub_weights)\n",
    "m = m.eval()\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f1d192ad-4902-49e1-baf6-77e7474a33bf",
   "metadata": {},
   "source": [
    "# Inference test"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fe6345f3-97b5-46cd-b068-aea129aedde4",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Download the ImageNet class index; wget -O overwrites any existing local copy\n",
    "!wget https://s3.amazonaws.com/deep-learning-models/image-models/imagenet_class_index.json -O imagenet_class_index.json\n",
    "with open(\"imagenet_class_index.json\", \"r\") as class_index_file:\n",
    "    imagenet_classnames = json.load(class_index_file)\n",
    "\n",
    "# Map each class id to the second element of its entry, used below as the label name\n",
    "imagenet_id_to_classname = {class_id: entry[1] for class_id, entry in imagenet_classnames.items()}\n",
    "\n",
    "# Download the example image (also overwritten on every run)\n",
    "!wget -O library.jpg https://upload.wikimedia.org/wikipedia/commons/thumb/c/c5/13-11-02-olb-by-RalfR-03.jpg/800px-13-11-02-olb-by-RalfR-03.jpg\n",
    "\n",
    "image_path = \"library.jpg\"\n",
    "image_pil = Image.open(image_path).convert(\"RGB\")\n",
    "plt.figure(figsize=(6, 6))\n",
    "plt.imshow(image_pil)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "19d91191-d461-4ae7-93cf-f8eb325133be",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Per-channel statistics passed to T.Normalize\n",
    "NORM_MEAN = [0.485, 0.456, 0.406]\n",
    "NORM_STD = [0.229, 0.224, 0.225]\n",
    "\n",
    "preprocessing_steps = [\n",
    "    T.Resize(224),\n",
    "    T.CenterCrop(224),\n",
    "    T.ToTensor(),\n",
    "    T.Normalize(mean=NORM_MEAN, std=NORM_STD),\n",
    "]\n",
    "image_transform = T.Compose(preprocessing_steps)\n",
    "image_chw = image_transform(image_pil)  # C H W\n",
    "\n",
    "# Insert a batch axis in front and a time (D) axis right after the channels\n",
    "image = image_chw[None, :, None]  # B C D H W"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "22a4b65b-e9ce-48e3-9f1b-e79acb5a38f5",
   "metadata": {},
   "outputs": [],
   "source": [
    "def infer(model):\n",
    "    \"\"\"Run ``model`` on the notebook-level ``image`` tensor and print its five top-scoring class names.\"\"\"\n",
    "    with torch.no_grad():\n",
    "        logits = model(image, input_type=\"image\")\n",
    "    # Only the first (and only) batch row is reported\n",
    "    top5_ids = logits.topk(k=5).indices[0]\n",
    "\n",
    "    label_names = []\n",
    "    for class_id in top5_ids:\n",
    "        label_names.append(imagenet_id_to_classname[str(class_id.item())])\n",
    "    print(\"Top 5 predicted labels: %s\" % \", \".join(label_names))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b4305dd1-64ef-4a0f-b817-f5efe5f23980",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Run both models on the same image; the two printed top-5 label lists should be identical\n",
    "for candidate_model in (m, mhub):\n",
    "    infer(candidate_model)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "779c2578-1cba-478c-8909-f47330e1b376",
   "metadata": {},
   "source": [
    "# Verify that the trunk / encoder outputs are the same"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2d85f5d7-f77d-47d3-8bec-3bfa51687953",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Inference-only comparison: skip autograd tracking while extracting features\n",
    "with torch.no_grad():\n",
    "    m_feature = m.encoder(image)\n",
    "    mhub_feature = mhub.trunk(image)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b5caf4aa-b964-4969-b531-318b9721bf17",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Peek at the first 10 feature values of each model; they should agree\n",
    "m_head = m_feature.flatten()[:10]\n",
    "mhub_head = mhub_feature[0].flatten()[:10]\n",
    "m_head, mhub_head"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ebd36e03-c86a-4d4f-a83a-b8ff3bd757b9",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Exact element-wise comparison over every feature value\n",
    "features_equal = np.array(m_feature == mhub_feature[0])\n",
    "features_equal.all()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "08adda52-d593-4d58-816a-4a6a6205ce3d",
   "metadata": {},
   "source": [
    "# Test on randomly generated input"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1dcacbc6-5cbc-4e01-b0c9-8404a33f13da",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Random input pushed through both models with input_type=\"video\"\n",
    "# (presumably laid out as B C D H W like the image tensor; TODO confirm)\n",
    "mock_video = torch.randn(1, 3, 10, 112, 112)\n",
    "\n",
    "m_output = m(mock_video, input_type=\"video\")\n",
    "mhub_output = mhub(mock_video, input_type=\"video\")\n",
    "\n",
    "# Exact element-wise comparison of the two outputs\n",
    "video_outputs_equal = np.array(m_output == mhub_output[0])\n",
    "video_outputs_equal.all()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "357fa8cd-a853-4b12-ad87-44dba43c133b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Random input for the \"rgbd\" path (presumably B C D H W with 4 channels and one frame; TODO confirm)\n",
    "mock_depth = torch.randn(1, 4, 1, 112, 112)\n",
    "\n",
    "# Feed the tensor created in this cell rather than reusing the video input\n",
    "m_output = m(mock_depth, input_type=\"rgbd\")\n",
    "mhub_output = mhub(mock_depth, input_type=\"rgbd\")\n",
    "\n",
    "np.all(np.array(m_output == mhub_output[0]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6bc03203-430f-4ab9-86a0-acbfadc89f67",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
